%First plotting of example traces, Figure 2A
%Reorganize the data into a more palatable form: for each mouse, build
%[session x intensity] matrices of trial and hit counts for the baseline
%(Base) and post-exposure (Expose) periods and store them in noiseData or
%shamData according to the mouse's cohort.
clear;clc;
sham = [5,6,7,8,11,13];
noiseMice = [1,2,3,4,9,10,12]; %Noise-exposed mice (not named exp, which would shadow the builtin)
%Load the data (provides the Base and Expose structs used below)
load('detectionBehaviorData.mat')
frequency = 1; %1 is 8 kHz, 2 is 32 kHz

noiseCount = 1;
shamCount = 1;

for mouse = 1:13 %Loop through mice

    %%%%%%%%%
    %Baseline sessions
    %%%%%%%%%

    %Get the list of intensities tested across all baseline sessions
    intensities = [];
    for day = 1:length(Base.func{mouse})
        intensities = [intensities Base.trialType{mouse}{day}];
    end
    allIntensities = sort(unique(intensities));
    allHits = zeros(length(Base.func{mouse}),length(allIntensities));
    allNTrials = zeros(length(Base.func{mouse}),length(allIntensities));

    %Fill one row per baseline session; intensities not tested stay 0
    for day = 1:length(Base.func{mouse}) %Loop through the sessions

        %Get the data (hit count = trials .* func, func presumably a hit rate)
        nTrials = Base.numTrial{mouse}{day}{frequency};
        hits = nTrials.*Base.func{mouse}{day}{frequency};
        intensities = Base.trialType{mouse}{day};

        %Place each tested intensity in its column
        for i = 1:length(intensities)
            idx = find(allIntensities == intensities(i));
            allHits(day,idx) = hits(i);
            allNTrials(day,idx) = nTrials(i);
        end
    end

    %Account for NaN's (a NaN func value produces a NaN hit count)
    allHits(isnan(allHits)) = 0;

    %Now save the data into this mouse's cohort struct
    if ismember(mouse,noiseMice)
        noiseData(noiseCount).baseIntensities = allIntensities;
        noiseData(noiseCount).baseNTrials = allNTrials;
        noiseData(noiseCount).baseHits = allHits;
        noiseData(noiseCount).baseThresh = Base.thresh{mouse}(:,frequency);
    elseif ismember(mouse,sham)
        shamData(shamCount).baseIntensities = allIntensities;
        shamData(shamCount).baseNTrials = allNTrials;
        shamData(shamCount).baseHits = allHits;
        shamData(shamCount).baseThresh = Base.thresh{mouse}(:,frequency);
    end

    %%%%%%%%%
    %Post-exposure sessions (same procedure)
    %%%%%%%%%
    intensities = [];
    for day = 1:length(Expose.func{mouse})
        intensities = [intensities Expose.trialType{mouse}{day}];
    end
    allIntensities = sort(unique(intensities));
    allHits = zeros(length(Expose.func{mouse}),length(allIntensities));
    allNTrials = zeros(length(Expose.func{mouse}),length(allIntensities));

    %Fill one row per post-exposure session; intensities not tested stay 0
    for day = 1:length(Expose.func{mouse}) %Loop through the sessions

        %Get the data
        nTrials = Expose.numTrial{mouse}{day}{frequency};
        hits = nTrials.*Expose.func{mouse}{day}{frequency};
        intensities = Expose.trialType{mouse}{day};

        %Place each tested intensity in its column
        for i = 1:length(intensities)
            idx = find(allIntensities == intensities(i));
            allHits(day,idx) = hits(i);
            allNTrials(day,idx) = nTrials(i);
        end
    end

    %Account for NaN's
    allHits(isnan(allHits)) = 0;

    %Some mice are missing a post-exposure day (day 10). For mice with 14
    %sessions, insert an empty row (and a NaN threshold) at position 10 so
    %their rows presumably line up with mice that have all 15 sessions.
    threshs = Expose.thresh{mouse}(:,frequency);
    if length(Expose.func{mouse}) == 14
        allHits = [allHits(1:9,:); zeros(1,size(allHits,2)); allHits(10:end,:)];
        allNTrials = [allNTrials(1:9,:); zeros(1,size(allNTrials,2)); allNTrials(10:end,:)];
        threshs = [threshs(1:9,:); NaN; threshs(10:end,:)];
    end

    %Now save the data and advance this cohort's counter
    if ismember(mouse,noiseMice)
        noiseData(noiseCount).postIntensities = allIntensities;
        noiseData(noiseCount).postNTrials = allNTrials;
        noiseData(noiseCount).postHits = allHits;
        noiseData(noiseCount).postThresh = threshs;
        noiseCount = noiseCount + 1;
    elseif ismember(mouse,sham)
        shamData(shamCount).postIntensities = allIntensities;
        shamData(shamCount).postNTrials = allNTrials;
        shamData(shamCount).postHits = allHits;
        shamData(shamCount).postThresh = threshs;
        shamCount = shamCount + 1;
    end

end

clearvars -except noiseData shamData

%Example noise-exposed mouse (element 3 of noiseData): d' versus intensity
%relative to threshold for baseline (dashed black) and for pooled groups of
%post-exposure sessions (shades of red).
mouse = 3;
clearvars func_base func_post
trialThreshold = 10;   %Intensities with fewer pooled trials are dropped
groups = [3 3 3 4];    %Number of consecutive post-exposure sessions per curve
%d' transform; 0.99*p + 0.005 keeps rates of 0 or 1 finite under norminv
toZ = @(p) norminv(p.*.99 + 0.005);

%Baseline curve: pool baseline sessions 1 and 3
baseRows = [1 3];
trialsPooled = nansum(noiseData(mouse).baseNTrials(baseRows,:),1);
hitsPooled = nansum(noiseData(mouse).baseHits(baseRows,:),1);
keep = trialsPooled >= trialThreshold;
trialsPooled = trialsPooled(keep);
hitsPooled = hitsPooled(keep);
keptLevels = noiseData(mouse).baseIntensities(keep);
%The first kept intensity is the false-alarm reference
dPrimes = toZ(hitsPooled(2:end)./trialsPooled(2:end)) - toZ(hitsPooled(1)/trialsPooled(1));
func_base = [keptLevels(2:end) - nanmean(noiseData(mouse).baseThresh); dPrimes];

%Post-exposure curves: consecutive session blocks starting after session 2
%NOTE(review): sessions 1-2 are not plotted here - confirm this is intended
lastRow = 2;
func_post = cell(1,numel(groups));
for g = 1:numel(groups)
    sessRows = lastRow + (1:groups(g));
    lastRow = sessRows(end);

    trialsPooled = nansum(noiseData(mouse).postNTrials(sessRows,:),1);
    hitsPooled = nansum(noiseData(mouse).postHits(sessRows,:),1);
    keep = trialsPooled >= trialThreshold;
    trialsPooled = trialsPooled(keep);
    hitsPooled = hitsPooled(keep);
    keptLevels = noiseData(mouse).postIntensities(keep);

    dPrimes = toZ(hitsPooled(2:end)./trialsPooled(2:end)) - toZ(hitsPooled(1)/trialsPooled(1));
    func_post{g} = [keptLevels(2:end) - nanmean(noiseData(mouse).postThresh(sessRows)); dPrimes];
end

%Plot baseline then each post-exposure group
colorScale = {[0.7 0 0] [1 0 0] [1 0.3 0.3] [1 0.6 0.6]};
figure
plot(func_base(1,:),func_base(2,:),'k--','LineWidth',2)
hold on
for k = 1:numel(groups)
    plot(func_post{k}(1,:),func_post{k}(2,:),'Color',colorScale{k},'LineWidth',2)
end
ylim([-0.5 4])
xlim([-20 25])
xlabel('Intensity Re:Threshold (dB SPL)')
ylabel('D Prime')
box off

%Example sham-exposed mouse (element 3 of shamData): d' versus intensity
%relative to threshold for baseline (dashed black) and for pooled groups of
%post-exposure sessions (shades of grey).
mouse = 3;
clearvars func_base func_post
trialThreshold = 5; %Minimum pooled trials for an intensity to be kept
groups = [3 3 3 6]; %Number of consecutive post-exposure sessions per curve

%Baseline: pool over all baseline sessions
nTrials = nansum(shamData(mouse).baseNTrials,1); %Baseline trials
nHits = nansum(shamData(mouse).baseHits,1); %Baseline hits

%Only consider intensities with enough trials
index = nTrials >= trialThreshold;
nTrials = nTrials(index); nHits = nHits(index);
intensities = shamData(mouse).baseIntensities(index);

%Get the D Prime values. The first kept intensity is the false-alarm
%reference, and 0.99*p + 0.005 keeps rates of 0 or 1 finite under norminv.
FA = nHits(1)/nTrials(1);
hitRates = nHits(2:end)./nTrials(2:end);
FA = norminv(FA*.99 + 0.005);
dPrimes = norminv(hitRates.*.99 + 0.005); dPrimes = dPrimes - FA;
func_base = [(intensities(2:end) - nanmean(shamData(mouse).baseThresh));dPrimes];

%Post exposure: consecutive session blocks start:stop
start = 0; stop = 0;
func_post = cell(1,length(groups));
for g = 1:length(groups)

    start = stop + 1; stop = stop + groups(g);
    %Sum along sessions (dim 1) explicitly so that a one-session group
    %still yields one value per intensity
    nTrials = nansum(shamData(mouse).postNTrials(start:stop,:),1); %Post-exposure trials
    nHits = nansum(shamData(mouse).postHits(start:stop,:),1); %Post-exposure hits

    %Only consider intensities with enough trials
    index = nTrials >= trialThreshold;
    nTrials = nTrials(index); nHits = nHits(index);
    intensities = shamData(mouse).postIntensities(index);

    %Get the D Prime values
    FA = nHits(1)/nTrials(1);
    hitRates = nHits(2:end)./nTrials(2:end);
    FA = norminv(FA*.99 + 0.005);
    dPrimes = norminv(hitRates.*.99 + 0.005); dPrimes = dPrimes - FA;
    func_post{g} = [(intensities(2:end) - nanmean(shamData(mouse).postThresh(start:stop)));dPrimes];

end

%Plot baseline then every post-exposure group
colorScale = {[0.1 0.1 0.1] [0.3 0.3 0.3] [0.5 0.5 0.5] [0.7 0.7 0.7]};
figure
plot(func_base(1,:),func_base(2,:),'k--','LineWidth',2)
hold on
for i = 1:length(groups)
    plot(func_post{i}(1,:),func_post{i}(2,:),'Color',colorScale{i},'LineWidth',2)
end
ylim([-0.5 4])
xlim([-20 25])
xlabel('Intensity Re:Threshold (dB SPL)')
ylabel('D Prime')
box off

%%
%Next is the plotting for Figure 2B looking at slope change.
clear;clc;
%Reorganize the data into a more palatable form: for each mouse, build
%[session x intensity] matrices of trial and hit counts for the baseline
%(Base) and post-exposure (Expose) periods and store them in noiseData or
%shamData according to the mouse's cohort.
sham = [5,6,7,8,11,13];
noiseMice = [1,2,3,4,9,10,12]; %Noise-exposed mice (not named exp, which would shadow the builtin)
%Load the data (provides the Base and Expose structs used below)
load('detectionBehaviorData.mat')
frequency = 1; %1 is 8 kHz, 2 is 32 kHz

noiseCount = 1;
shamCount = 1;

for mouse = 1:13 %Loop through mice

    %%%%%%%%%
    %Baseline sessions
    %%%%%%%%%

    %Get the list of intensities tested across all baseline sessions
    intensities = [];
    for day = 1:length(Base.func{mouse})
        intensities = [intensities Base.trialType{mouse}{day}];
    end
    allIntensities = sort(unique(intensities));
    allHits = zeros(length(Base.func{mouse}),length(allIntensities));
    allNTrials = zeros(length(Base.func{mouse}),length(allIntensities));

    %Fill one row per baseline session; intensities not tested stay 0
    for day = 1:length(Base.func{mouse}) %Loop through the sessions

        %Get the data (hit count = trials .* func, func presumably a hit rate)
        nTrials = Base.numTrial{mouse}{day}{frequency};
        hits = nTrials.*Base.func{mouse}{day}{frequency};
        intensities = Base.trialType{mouse}{day};

        %Place each tested intensity in its column
        for i = 1:length(intensities)
            idx = find(allIntensities == intensities(i));
            allHits(day,idx) = hits(i);
            allNTrials(day,idx) = nTrials(i);
        end
    end

    %Account for NaN's (a NaN func value produces a NaN hit count)
    allHits(isnan(allHits)) = 0;

    %Now save the data into this mouse's cohort struct
    if ismember(mouse,noiseMice)
        noiseData(noiseCount).baseIntensities = allIntensities;
        noiseData(noiseCount).baseNTrials = allNTrials;
        noiseData(noiseCount).baseHits = allHits;
        noiseData(noiseCount).baseThresh = Base.thresh{mouse}(:,frequency);
    elseif ismember(mouse,sham)
        shamData(shamCount).baseIntensities = allIntensities;
        shamData(shamCount).baseNTrials = allNTrials;
        shamData(shamCount).baseHits = allHits;
        shamData(shamCount).baseThresh = Base.thresh{mouse}(:,frequency);
    end

    %%%%%%%%%
    %Post-exposure sessions (same procedure)
    %%%%%%%%%
    intensities = [];
    for day = 1:length(Expose.func{mouse})
        intensities = [intensities Expose.trialType{mouse}{day}];
    end
    allIntensities = sort(unique(intensities));
    allHits = zeros(length(Expose.func{mouse}),length(allIntensities));
    allNTrials = zeros(length(Expose.func{mouse}),length(allIntensities));

    %Fill one row per post-exposure session; intensities not tested stay 0
    for day = 1:length(Expose.func{mouse}) %Loop through the sessions

        %Get the data
        nTrials = Expose.numTrial{mouse}{day}{frequency};
        hits = nTrials.*Expose.func{mouse}{day}{frequency};
        intensities = Expose.trialType{mouse}{day};

        %Place each tested intensity in its column
        for i = 1:length(intensities)
            idx = find(allIntensities == intensities(i));
            allHits(day,idx) = hits(i);
            allNTrials(day,idx) = nTrials(i);
        end
    end

    %Account for NaN's
    allHits(isnan(allHits)) = 0;

    %Some mice are missing a post-exposure day (day 10). For mice with 14
    %sessions, insert an empty row (and a NaN threshold) at position 10 so
    %their rows presumably line up with mice that have all 15 sessions.
    threshs = Expose.thresh{mouse}(:,frequency);
    if length(Expose.func{mouse}) == 14
        allHits = [allHits(1:9,:); zeros(1,size(allHits,2)); allHits(10:end,:)];
        allNTrials = [allNTrials(1:9,:); zeros(1,size(allNTrials,2)); allNTrials(10:end,:)];
        threshs = [threshs(1:9,:); NaN; threshs(10:end,:)];
    end

    %Now save the data and advance this cohort's counter
    if ismember(mouse,noiseMice)
        noiseData(noiseCount).postIntensities = allIntensities;
        noiseData(noiseCount).postNTrials = allNTrials;
        noiseData(noiseCount).postHits = allHits;
        noiseData(noiseCount).postThresh = threshs;
        noiseCount = noiseCount + 1;
    elseif ismember(mouse,sham)
        shamData(shamCount).postIntensities = allIntensities;
        shamData(shamCount).postNTrials = allNTrials;
        shamData(shamCount).postHits = allHits;
        shamData(shamCount).postThresh = threshs;
        shamCount = shamCount + 1;
    end

end

clearvars -except noiseData shamData

%Now for each mouse let's examine the change in the slope of the d Prime
%curve pre/post exposure.
%
%For each period, trials and hits are pooled over sessions, and intensities
%with fewer than trialThreshold trials are dropped. The first kept
%intensity is the false-alarm reference. d' is evaluated from the second
%kept intensity up to lastInt, which is the first intensity whose hit rate
%exceeds saturationThreshold (or else the last kept intensity). The stored
%slope is mean(diff(dPrimes))/(intensities(lastInt) - intensities(2)).
%NOTE(review): mean(diff(d')) is the d' change per intensity step, but it is
%divided by the whole dB span, so the value also shrinks with the number of
%steps. (dPrimes(end)-dPrimes(1))/(intensities(lastInt)-intensities(2))
%would be d' per dB. Confirm which is intended.

trialThreshold = 7; %Minimum number of trials needed to be kept
saturationThreshold = 0.9; %To find the range to evaluate the d Prime (slightly above 90%)
groups = [3 3 3 6]; %Consecutive post-exposure sessions pooled per period

%Column 1 = baseline, columns 2..end = post-exposure periods
noiseSlope = nan(length(noiseData),length(groups)+1);

for mouse = 1:length(noiseData) %Loop through mice

    %%%%%%%%%%%
    %First evaluate this mouse in baseline
    %%%%%%%%%%%
    %Sum along sessions (dim 1) explicitly so a single baseline session
    %still yields one value per intensity
    nTrials = sum(noiseData(mouse).baseNTrials,1); %Baseline trials
    nHits = sum(noiseData(mouse).baseHits,1); %Baseline hits

    %Only consider intensities with enough trials
    index = nTrials >= trialThreshold;
    nTrials = nTrials(index); nHits = nHits(index);
    intensities = noiseData(mouse).baseIntensities(index);

    %Now get the d Prime over a useful range
    lastInt = find(nHits./nTrials > saturationThreshold, 1, 'first'); %Account for saturation
    if isempty(lastInt)
        lastInt = length(nTrials);
    end

    %0.99*p + 0.005 keeps rates of 0 or 1 finite under norminv
    FA = nHits(1)/nTrials(1);
    hitRates = nHits(2:lastInt)./nTrials(2:lastInt);
    FA = norminv(FA*.99 + 0.005);
    dPrimes = norminv(hitRates.*.99 + 0.005); dPrimes = dPrimes - FA;

    %Slope measurement in baseline
    noiseSlope(mouse,1) = mean(diff(dPrimes))/(intensities(lastInt) - intensities(2));

    %%%%%%%%%%%
    %Next evaluate the mouse in post exposure period
    %%%%%%%%%%%
    start = 0; stop = 0;
    for g = 1:length(groups)

        %Pool sessions start:stop
        start = stop + 1; stop = stop + groups(g);
        nTrials = nansum(noiseData(mouse).postNTrials(start:stop,:),1); %Post-exposure trials
        nHits = nansum(noiseData(mouse).postHits(start:stop,:),1); %Post-exposure hits

        %Only consider intensities with enough trials
        index = nTrials >= trialThreshold;
        nTrials = nTrials(index); nHits = nHits(index);
        intensities = noiseData(mouse).postIntensities(index);

        %A slope needs the reference plus at least two intensities; fewer
        %gives NaN or an index error at intensities(2). This also covers
        %mice that lost a day, and matches the sham loop's rule.
        if length(nTrials) < 3
            noiseSlope(mouse,g+1) = NaN;
            continue
        end

        %Now get the d Prime over a useful range
        lastInt = find(nHits./nTrials > saturationThreshold, 1, 'first'); %Account for saturation
        if isempty(lastInt)
            lastInt = length(nTrials);
        end

        FA = nHits(1)/nTrials(1);
        hitRates = nHits(2:lastInt)./nTrials(2:lastInt);
        FA = norminv(FA*.99 + 0.005);
        dPrimes = norminv(hitRates.*.99 + 0.005); dPrimes = dPrimes - FA;

        %Slope measurement for this post-exposure period
        noiseSlope(mouse,g+1) = mean(diff(dPrimes))/(intensities(lastInt) - intensities(2));
    end
end



%%%%%%%%%%%%%%%%%%%%%%%%%%
%Do same as above but now for the sham exposed mice
%%%%%%%%%%%%%%%%%%%%%%%%%%
%Uses trialThreshold, saturationThreshold and groups as set for the
%noise-exposed loop. Column 1 = baseline, columns 2..end = post periods.

shamSlope = nan(length(shamData),length(groups)+1);

for mouse = 1:length(shamData) %Loop through mice

    %%%%%%%%%%%
    %First evaluate this mouse in baseline
    %%%%%%%%%%%
    %Sum along sessions (dim 1) explicitly so a single baseline session
    %still yields one value per intensity
    nTrials = sum(shamData(mouse).baseNTrials,1); %Baseline trials
    nHits = sum(shamData(mouse).baseHits,1); %Baseline hits

    %Only consider intensities with enough trials
    index = nTrials >= trialThreshold;
    nTrials = nTrials(index); nHits = nHits(index);
    intensities = shamData(mouse).baseIntensities(index);

    %Now get the d Prime over a useful range
    lastInt = find(nHits./nTrials > saturationThreshold, 1, 'first'); %Account for saturation
    if isempty(lastInt)
        lastInt = length(nTrials);
    end

    %0.99*p + 0.005 keeps rates of 0 or 1 finite under norminv
    FA = nHits(1)/nTrials(1);
    hitRates = nHits(2:lastInt)./nTrials(2:lastInt);
    FA = norminv(FA*.99 + 0.005);
    dPrimes = norminv(hitRates.*.99 + 0.005); dPrimes = dPrimes - FA;

    %Slope measurement in baseline
    shamSlope(mouse,1) = mean(diff(dPrimes))/(intensities(lastInt) - intensities(2));

    %%%%%%%%%%%
    %Next evaluate the mouse in post exposure period
    %%%%%%%%%%%
    start = 0; stop = 0;
    for g = 1:length(groups)

        %Pool sessions start:stop
        start = stop + 1; stop = stop + groups(g);
        nTrials = nansum(shamData(mouse).postNTrials(start:stop,:),1); %Post-exposure trials
        nHits = nansum(shamData(mouse).postHits(start:stop,:),1); %Post-exposure hits

        %Only consider intensities with enough trials
        index = nTrials >= trialThreshold;
        nTrials = nTrials(index); nHits = nHits(index);
        intensities = shamData(mouse).postIntensities(index);

        %A slope needs the reference plus at least two intensities (this
        %also covers mice that lost a day)
        if length(nTrials) < 3
            shamSlope(mouse,g+1) = NaN;
            continue
        end

        %Now get the d Prime over a useful range
        lastInt = find(nHits./nTrials > saturationThreshold, 1, 'first'); %Account for saturation
        if isempty(lastInt)
            lastInt = length(nTrials);
        end

        FA = nHits(1)/nTrials(1);
        hitRates = nHits(2:lastInt)./nTrials(2:lastInt);
        FA = norminv(FA*.99 + 0.005);
        dPrimes = norminv(hitRates.*.99 + 0.005); dPrimes = dPrimes - FA;

        %Slope measurement for this post-exposure period
        shamSlope(mouse,g+1) = mean(diff(dPrimes))/(intensities(lastInt) - intensities(2));
    end
end

%Now make everything relative to baseline (column 1 becomes 1)
shamSlope = shamSlope./shamSlope(:,1);
noiseSlope = noiseSlope./noiseSlope(:,1);

%Mean and standard error across mice. Mice without an estimate (NaN) in a
%period are skipped, so the SEM divides by the number of mice that actually
%contributed to each period rather than the fixed cohort size.
shamMean = nanmean(shamSlope,1);
noiseMean = nanmean(noiseSlope,1);
shamErr = nanstd(shamSlope,[],1)./sqrt(sum(~isnan(shamSlope),1));
noiseErr = nanstd(noiseSlope,[],1)./sqrt(sum(~isnan(noiseSlope),1));

%Baseline plus one x position per post-exposure period
xPos = 1:(length(groups)+1);

figure
shadedErrorBar(xPos,shamMean,shamErr,'lineprops',{'Color','k'})
hold on
p1 = plot(xPos,shamMean,'k');
shadedErrorBar(xPos,noiseMean,noiseErr,'lineprops',{'Color','r'})
p2 = plot(xPos,noiseMean,'r');
box off
ylabel('Change in Slope (dPrime/dB SPL) Re:Baseline')
xlabel('Days Re:Exposure')
ax = gca; ax.XTick = xPos;
ax.XTickLabel = {'Baseline' 'Days 0-2' 'Days 3-5' 'Days 6-8' 'Days 10+'};
ax.XTickLabelRotation = -45;
ylim([0 4])
%Legend only on the mean lines, not on the shaded error patches
legend([p1 p2],{'Sham' 'Trauma'})

%%
%Also the statistics for Figure 2B

%First we will perform a repeated measures ANOVA on the slope data.
clear;clc;
%Reorganize the data into a more palatable form: for each mouse, build
%[session x intensity] matrices of trial and hit counts for the baseline
%(Base) and post-exposure (Expose) periods and store them in noiseData or
%shamData according to the mouse's cohort.
sham = [5,6,7,8,11,13];
noiseMice = [1,2,3,4,9,10,12]; %Noise-exposed mice (not named exp, which would shadow the builtin)
%Load the data (provides the Base and Expose structs used below)
load('detectionBehaviorData.mat')
frequency = 1; %1 is 8 kHz, 2 is 32 kHz

noiseCount = 1;
shamCount = 1;

for mouse = 1:13 %Loop through mice

    %%%%%%%%%
    %Baseline sessions
    %%%%%%%%%

    %Get the list of intensities tested across all baseline sessions
    intensities = [];
    for day = 1:length(Base.func{mouse})
        intensities = [intensities Base.trialType{mouse}{day}];
    end
    allIntensities = sort(unique(intensities));
    allHits = zeros(length(Base.func{mouse}),length(allIntensities));
    allNTrials = zeros(length(Base.func{mouse}),length(allIntensities));

    %Fill one row per baseline session; intensities not tested stay 0
    for day = 1:length(Base.func{mouse}) %Loop through the sessions

        %Get the data (hit count = trials .* func, func presumably a hit rate)
        nTrials = Base.numTrial{mouse}{day}{frequency};
        hits = nTrials.*Base.func{mouse}{day}{frequency};
        intensities = Base.trialType{mouse}{day};

        %Place each tested intensity in its column
        for i = 1:length(intensities)
            idx = find(allIntensities == intensities(i));
            allHits(day,idx) = hits(i);
            allNTrials(day,idx) = nTrials(i);
        end
    end

    %Account for NaN's (a NaN func value produces a NaN hit count)
    allHits(isnan(allHits)) = 0;

    %Now save the data into this mouse's cohort struct
    if ismember(mouse,noiseMice)
        noiseData(noiseCount).baseIntensities = allIntensities;
        noiseData(noiseCount).baseNTrials = allNTrials;
        noiseData(noiseCount).baseHits = allHits;
        noiseData(noiseCount).baseThresh = Base.thresh{mouse}(:,frequency);
    elseif ismember(mouse,sham)
        shamData(shamCount).baseIntensities = allIntensities;
        shamData(shamCount).baseNTrials = allNTrials;
        shamData(shamCount).baseHits = allHits;
        shamData(shamCount).baseThresh = Base.thresh{mouse}(:,frequency);
    end

    %%%%%%%%%
    %Post-exposure sessions (same procedure)
    %%%%%%%%%
    intensities = [];
    for day = 1:length(Expose.func{mouse})
        intensities = [intensities Expose.trialType{mouse}{day}];
    end
    allIntensities = sort(unique(intensities));
    allHits = zeros(length(Expose.func{mouse}),length(allIntensities));
    allNTrials = zeros(length(Expose.func{mouse}),length(allIntensities));

    %Fill one row per post-exposure session; intensities not tested stay 0
    for day = 1:length(Expose.func{mouse}) %Loop through the sessions

        %Get the data
        nTrials = Expose.numTrial{mouse}{day}{frequency};
        hits = nTrials.*Expose.func{mouse}{day}{frequency};
        intensities = Expose.trialType{mouse}{day};

        %Place each tested intensity in its column
        for i = 1:length(intensities)
            idx = find(allIntensities == intensities(i));
            allHits(day,idx) = hits(i);
            allNTrials(day,idx) = nTrials(i);
        end
    end

    %Account for NaN's
    allHits(isnan(allHits)) = 0;

    %Some mice are missing a post-exposure day (day 10). For mice with 14
    %sessions, insert an empty row (and a NaN threshold) at position 10 so
    %their rows presumably line up with mice that have all 15 sessions.
    threshs = Expose.thresh{mouse}(:,frequency);
    if length(Expose.func{mouse}) == 14
        allHits = [allHits(1:9,:); zeros(1,size(allHits,2)); allHits(10:end,:)];
        allNTrials = [allNTrials(1:9,:); zeros(1,size(allNTrials,2)); allNTrials(10:end,:)];
        threshs = [threshs(1:9,:); NaN; threshs(10:end,:)];
    end

    %Now save the data and advance this cohort's counter
    if ismember(mouse,noiseMice)
        noiseData(noiseCount).postIntensities = allIntensities;
        noiseData(noiseCount).postNTrials = allNTrials;
        noiseData(noiseCount).postHits = allHits;
        noiseData(noiseCount).postThresh = threshs;
        noiseCount = noiseCount + 1;
    elseif ismember(mouse,sham)
        shamData(shamCount).postIntensities = allIntensities;
        shamData(shamCount).postNTrials = allNTrials;
        shamData(shamCount).postHits = allHits;
        shamData(shamCount).postThresh = threshs;
        shamCount = shamCount + 1;
    end

end

clearvars -except noiseData shamData

%Now for each mouse let's examine the change in the slope of the d Prime
%curve pre/post exposure.
%
%For each period, trials and hits are pooled over sessions, and intensities
%with fewer than trialThreshold trials are dropped. The first kept
%intensity is the false-alarm reference. d' is evaluated from the second
%kept intensity up to lastInt, which is the first intensity whose hit rate
%exceeds saturationThreshold (or else the last kept intensity). The stored
%slope is mean(diff(dPrimes))/(intensities(lastInt) - intensities(2)).
%NOTE(review): mean(diff(d')) is the d' change per intensity step, but it is
%divided by the whole dB span, so the value also shrinks with the number of
%steps. An alternative that was considered is
%(dPrimes(end) - dPrimes(1))/(intensities(lastInt) - intensities(2)), which
%is d' per dB. Confirm which is intended.

trialThreshold = 7; %Minimum number of trials needed to be kept
saturationThreshold = 0.9; %To find the range to evaluate the d Prime (slightly above 90%)
groups = [3 3 3 6]; %Consecutive post-exposure sessions pooled per period

%Column 1 = baseline, columns 2..end = post-exposure periods
noiseSlope = nan(length(noiseData),length(groups)+1);

for mouse = 1:length(noiseData) %Loop through mice

    %%%%%%%%%%%
    %First evaluate this mouse in baseline
    %%%%%%%%%%%
    %Sum along sessions (dim 1) explicitly so a single baseline session
    %still yields one value per intensity
    nTrials = sum(noiseData(mouse).baseNTrials,1); %Baseline trials
    nHits = sum(noiseData(mouse).baseHits,1); %Baseline hits

    %Only consider intensities with enough trials
    index = nTrials >= trialThreshold;
    nTrials = nTrials(index); nHits = nHits(index);
    intensities = noiseData(mouse).baseIntensities(index);

    %Now get the d Prime over a useful range
    lastInt = find(nHits./nTrials > saturationThreshold, 1, 'first'); %Account for saturation
    if isempty(lastInt)
        lastInt = length(nTrials);
    end

    %0.99*p + 0.005 keeps rates of 0 or 1 finite under norminv
    FA = nHits(1)/nTrials(1);
    hitRates = nHits(2:lastInt)./nTrials(2:lastInt);
    FA = norminv(FA*.99 + 0.005);
    dPrimes = norminv(hitRates.*.99 + 0.005); dPrimes = dPrimes - FA;

    %Slope measurement in baseline
    noiseSlope(mouse,1) = mean(diff(dPrimes))/(intensities(lastInt) - intensities(2));

    %%%%%%%%%%%
    %Next evaluate the mouse in post exposure period
    %%%%%%%%%%%
    start = 0; stop = 0;
    for g = 1:length(groups)

        %Pool sessions start:stop
        start = stop + 1; stop = stop + groups(g);
        nTrials = nansum(noiseData(mouse).postNTrials(start:stop,:),1); %Post-exposure trials
        nHits = nansum(noiseData(mouse).postHits(start:stop,:),1); %Post-exposure hits

        %Only consider intensities with enough trials
        index = nTrials >= trialThreshold;
        nTrials = nTrials(index); nHits = nHits(index);
        intensities = noiseData(mouse).postIntensities(index);

        %A slope needs the reference plus at least two intensities; fewer
        %gives NaN or an index error at intensities(2). This also covers
        %mice that lost a day, and matches the sham loop's rule.
        if length(nTrials) < 3
            noiseSlope(mouse,g+1) = NaN;
            continue
        end

        %Now get the d Prime over a useful range
        lastInt = find(nHits./nTrials > saturationThreshold, 1, 'first'); %Account for saturation
        if isempty(lastInt)
            lastInt = length(nTrials);
        end

        FA = nHits(1)/nTrials(1);
        hitRates = nHits(2:lastInt)./nTrials(2:lastInt);
        FA = norminv(FA*.99 + 0.005);
        dPrimes = norminv(hitRates.*.99 + 0.005); dPrimes = dPrimes - FA;

        %Slope measurement for this post-exposure period
        noiseSlope(mouse,g+1) = mean(diff(dPrimes))/(intensities(lastInt) - intensities(2));
    end
end



%%%%%%%%%%%%%%%%%%%%%%%%%%
%Do same as above but now for the sham exposed mice
%%%%%%%%%%%%%%%%%%%%%%%%%%
%Uses trialThreshold, saturationThreshold and groups as set for the
%noise-exposed loop. Column 1 = baseline, columns 2..end = post periods.

shamSlope = nan(length(shamData),length(groups)+1);

for mouse = 1:length(shamData) %Loop through mice

    %%%%%%%%%%%
    %First evaluate this mouse in baseline
    %%%%%%%%%%%
    %Sum along sessions (dim 1) explicitly so a single baseline session
    %still yields one value per intensity
    nTrials = sum(shamData(mouse).baseNTrials,1); %Baseline trials
    nHits = sum(shamData(mouse).baseHits,1); %Baseline hits

    %Only consider intensities with enough trials
    index = nTrials >= trialThreshold;
    nTrials = nTrials(index); nHits = nHits(index);
    intensities = shamData(mouse).baseIntensities(index);

    %Now get the d Prime over a useful range
    lastInt = find(nHits./nTrials > saturationThreshold, 1, 'first'); %Account for saturation
    if isempty(lastInt)
        lastInt = length(nTrials);
    end

    %0.99*p + 0.005 keeps rates of 0 or 1 finite under norminv
    FA = nHits(1)/nTrials(1);
    hitRates = nHits(2:lastInt)./nTrials(2:lastInt);
    FA = norminv(FA*.99 + 0.005);
    dPrimes = norminv(hitRates.*.99 + 0.005); dPrimes = dPrimes - FA;

    %Slope measurement in baseline
    shamSlope(mouse,1) = mean(diff(dPrimes))/(intensities(lastInt) - intensities(2));

    %%%%%%%%%%%
    %Next evaluate the mouse in post exposure period
    %%%%%%%%%%%
    start = 0; stop = 0;
    for g = 1:length(groups)

        %Pool sessions start:stop
        start = stop + 1; stop = stop + groups(g);
        nTrials = nansum(shamData(mouse).postNTrials(start:stop,:),1); %Post-exposure trials
        nHits = nansum(shamData(mouse).postHits(start:stop,:),1); %Post-exposure hits

        %Only consider intensities with enough trials
        index = nTrials >= trialThreshold;
        nTrials = nTrials(index); nHits = nHits(index);
        intensities = shamData(mouse).postIntensities(index);

        %A slope needs the reference plus at least two intensities (this
        %also covers mice that lost a day)
        if length(nTrials) < 3
            shamSlope(mouse,g+1) = NaN;
            continue
        end

        %Now get the d Prime over a useful range
        lastInt = find(nHits./nTrials > saturationThreshold, 1, 'first'); %Account for saturation
        if isempty(lastInt)
            lastInt = length(nTrials);
        end

        FA = nHits(1)/nTrials(1);
        hitRates = nHits(2:lastInt)./nTrials(2:lastInt);
        FA = norminv(FA*.99 + 0.005);
        dPrimes = norminv(hitRates.*.99 + 0.005); dPrimes = dPrimes - FA;

        %Slope measurement for this post-exposure period
        shamSlope(mouse,g+1) = mean(diff(dPrimes))/(intensities(lastInt) - intensities(2));
    end
end

%Now make everything relative to baseline (column 1 becomes 1)
shamSlope = shamSlope./shamSlope(:,1);
noiseSlope = noiseSlope./noiseSlope(:,1);
clearvars -except shamSlope noiseSlope

%Table of post-exposure periods (the baseline column is dropped). There is
%one row per mouse, sham mice first, and one variable V1..Vk per period.
%The sizes come from the slope matrices, not from hard-coded counts.
nSham = size(shamSlope,1);
nNoise = size(noiseSlope,1);
nPeriods = size(shamSlope,2) - 1;
dataTable = [shamSlope(:,2:end); noiseSlope(:,2:end)];
varNames = arrayfun(@(k) sprintf('V%d',k), (1:nPeriods)', 'UniformOutput', false);
dataTable = array2table(dataTable, 'VariableNames', varNames);

%Add in the exposure group variable (same row order as above)
expGroup = [repmat({'Sham'},nSham,1); repmat({'Noise'},nNoise,1)];
dataTable.expGroup = expGroup;

%Next format the table containing the within subjects factors (the repeated
%measures): one 'Day' level per post-exposure period.
factorNames = {'Day'};

dayLabels = (1:nPeriods)';
dayLabels = arrayfun(@num2str, dayLabels, 'UniformOutput', 0);
withinTable = table(dayLabels,'VariableNames',factorNames);

%Now that we have the tables set up, run the anova model
rmSlope = fitrm(dataTable, sprintf('V1-V%d~expGroup',nPeriods),'WithinDesign',withinTable);
[rAnovaResults] = ranova(rmSlope, 'WithinModel','Day');
rAnovaResults

%%
%Now move to the optogenetic detection task, first looking at the change in
%detection threshold - Figure 2D.
clear;clc;

%Day numbers pooled into each period: days -5..-1 form the first period
%and days 0..3 the second
groups = {[-5;-4;-3;-2;-1] [0:3]};
load('optoDetectionData.mat')

%Output field for each entry of groups. Field order matters because later
%code reads the field names of groupedData and skips the first one
%('intensities').
periodFields = {'baseline','group1'};

%For every mouse, build a 2 x nIntensities matrix per period. Row 1 holds
%summed trials and row 2 summed hits, with columns following
%groupedData(mouse).intensities.
for mouse = 1:11
    allLevels = unique(vertcat(allOptoData(mouse).soundData.intensities));
    groupedData(mouse).intensities = allLevels;
    for f = 1:numel(periodFields)
        groupedData(mouse).(periodFields{f}) = zeros(2,length(allLevels));
    end

    sessionDays = vertcat(allOptoData(mouse).soundData.dayNum);

    for groupID = 1:length(groups)
        counts = groupedData(mouse).(periodFields{groupID});

        %Every session whose day number falls in this period
        for s = find(ismember(sessionDays,groups{groupID}))'
            session = allOptoData(mouse).soundData(s);
            k = 0;
            for level = session.intensities'
                k = k + 1;
                col = find(allLevels == level);
                counts(1,col) = counts(1,col) + session.nTrials(k);
                counts(2,col) = counts(2,col) + session.hits(k);
            end
        end

        groupedData(mouse).(periodFields{groupID}) = counts;
    end
end

%Estimate a detection threshold for each mouse and grouped period. A
%logistic function is fitted to hit rate versus intensity, and the
%intensity at which the fit comes closest to targetHit is taken.
%Threshold is defined by this fixed hit-rate criterion, not by a d' value.

groupNames = fieldnames(groupedData); groupNames = groupNames(2:end); %Skip 'intensities'
targetHit = 0.707;

%Loop through each mouse (mouse 5 is not analysed)
for mouse = [1:4 6:11]

    %Reset per mouse so no value can carry over from the previous mouse
    threshold = nan(1,length(groupNames));

    for day = 1:length(groupNames) %Loop through grouped days

        %Hit rate per intensity (row 1 = trials, row 2 = hits)
        func = groupedData(mouse).(groupNames{day});
        func = func(2,:)./func(1,:);

        %The first intensity is left out of the fit
        %NOTE(review): presumably the catch/no-stimulus level - confirm
        intensities = groupedData(mouse).intensities(2:end);
        func = func(2:end);

        %Now get a measurement of threshold from a logistic fit.
        %NOTE(review): hit rates are passed as proportions without trial
        %counts, so every intensity is weighted equally. glmfit accepts
        %[hits trials] for a count-weighted binomial fit.
        xFit = intensities(1):0.05:intensities(end);
        [logitCoef,~] = glmfit(intensities,func','binomial','logit');
        yFit = glmval(logitCoef,xFit,'logit');

        [~,I] = min(abs(yFit - targetHit)); threshold(day) = xFit(I);

    end

    %Store the data
    summaryData(mouse).threshold = threshold;

end

%Next make the threshold relative to baseline (the first grouped period)
for i = [1:4 6:11]
    summaryData(i).threshold = summaryData(i).threshold - summaryData(i).threshold(1);
end

%Collect threshold shifts by exposure group (one row per mouse, ascending
%mouse order). Noise group excludes the first test mouse (not age-matched)
%and the mouse with no threshold shift.
shamMice = [6 9 11];
noiseMice = [1:4 7:8];
thresh_sham = vertcat(summaryData(shamMice).threshold);
thresh_noise = vertcat(summaryData(noiseMice).threshold);

%Group means (column 1 is 0 by construction)
mean_sham = nanmean(thresh_sham,1);
mean_noise = nanmean(thresh_noise,1);

%Bar plot of the mean post-exposure shift (column 2) with each mouse
%overlaid. Columns are indexed explicitly so a genuine shift of exactly 0
%is not dropped.
figure
bar([mean_sham(2), mean_noise(2)],0.35)
hold on
plotSpread({thresh_sham(:,2) thresh_noise(:,2)},...
    'categoryIdx',[ones(1,size(thresh_sham,1)) 2.*ones(1,size(thresh_noise,1))])
ax = gca;
ax.XTickLabel = {'Sham' 'Trauma'};
ylabel('Threshold (dB SPL)')
box off

%%
%Now do the statistics for Figure 2D

%Paired t test of the first vs. second day-group column, run for the sham
%group and then for the noise-exposed group. Each test echoes p and the
%t statistic, and the two are separated by '---'.
pairedSets = {thresh_sham, thresh_noise};
for setIdx = 1:numel(pairedSets)
    if setIdx > 1
        disp('---')
    end
    thisSet = pairedSets{setIdx};
    [h,p,ci,stats] = ttest(thisSet(:,1),thisSet(:,2));
    p
    stats.tstat
end

%%
%Next create the example plots for Figure 2E

%Example d prime curves: mouse 8 is the example noise exposed mouse here;
%mouse 11 is the example sham exposed mouse in the next section.
clear;clc;
groups = {[-3;-5;-1] [0:3] [4:10]};
load('optoDetectionData.mat')

%Field receiving each group's counts: group 1 -> baseline, 2 -> group1,
%3 -> group2. Each field is 2 x nPowers: row 1 = summed trials,
%row 2 = summed hits, columns ordered like the powers field.
groupFields = {'baseline' 'group1' 'group2'};

for mouse = 1:10

    %Sorted, de-duplicated list of every power this mouse was tested at
    mousePowers = unique(vertcat(allOptoData(mouse).laserData.intensities));
    groupedData(mouse).powers = mousePowers;
    groupedData(mouse).baseline = zeros(2,length(mousePowers));
    groupedData(mouse).group1 = zeros(2,length(mousePowers));
    groupedData(mouse).group2 = zeros(2,length(mousePowers));

    %Day number of every session recorded for this mouse
    sessionDays = vertcat(allOptoData(mouse).laserData.dayNum);

    for groupID = 1:length(groups)
        target = groupFields{groupID};
        sessionIdx = find(ismember(sessionDays,groups{groupID}));

        %Add each matching session's trials/hits into the column of its power
        for s = sessionIdx'
            session = allOptoData(mouse).laserData(s);
            for k = 1:numel(session.intensities)
                col = find(mousePowers == session.intensities(k));
                groupedData(mouse).(target)(:,col) = ...
                    groupedData(mouse).(target)(:,col) + [session.nTrials(k); session.hits(k)];
            end
        end
    end
end

%Plot the d prime curves of the example mouse, one line per day group
groupNames = fields(groupedData); groupNames = groupNames(2:end);
colors = {[0 0 0] [1 0 0] [1 0.6 0.6]};
exampleMouse = 8;

figure
for g = 1:length(groupNames)
    counts = groupedData(exampleMouse).(groupNames{g});
    hitRate = counts(2,:)./counts(1,:);
    xPowers = groupedData(exampleMouse).powers(2:end);

    %d' = z(hit rate) - z(false-alarm rate); the first power column gives the
    %false-alarm rate, and rates are mapped into [0.005, 0.995] before norminv
    zFA = norminv(hitRate(1)*.99 + 0.005);
    dp = norminv(hitRate(2:end).*.99 + 0.005) - zFA;

    %Baseline dashed, post-exposure groups solid
    if g == 1
        style = '--';
    else
        style = '-';
    end
    plot(xPowers,dp,'LineStyle',style,'Color',colors{g},'LineWidth',1.75)
    hold on
end

box off
legend({'Baseline' 'Days 0-3' 'Days 4-10'})
ylabel('d Prime')
xlabel('Laser Power (dB mW)')
ylim([-1 4.5])

if ismember(exampleMouse,[6 9 11])
    suptitle('Sham Exposed Mouse')
else
    suptitle('Noise Exposed Mouse')
end

clear;clc;
groups = {[-1;-3] [0 1 2] [6 10]}; %For sham mouse 11
load('optoDetectionData.mat')

%Target field per group index; each is 2 x nPowers with row 1 = summed
%trials and row 2 = summed hits
groupFields = {'baseline' 'group1' 'group2'};

for mouse = 1:11

    testedPowers = unique(vertcat(allOptoData(mouse).laserData.intensities));
    nPowers = length(testedPowers);
    groupedData(mouse).powers = testedPowers;
    groupedData(mouse).baseline = zeros(2,nPowers);
    groupedData(mouse).group1 = zeros(2,nPowers);
    groupedData(mouse).group2 = zeros(2,nPowers);

    dayNums = vertcat(allOptoData(mouse).laserData.dayNum);

    for groupID = 1:length(groups)
        sessionIdx = find(ismember(dayNums,groups{groupID}));
        for n = 1:numel(sessionIdx)
            sess = allOptoData(mouse).laserData(sessionIdx(n));
            for k = 1:numel(sess.intensities)
                col = find(testedPowers == sess.intensities(k));
                counts = groupedData(mouse).(groupFields{groupID});
                counts(:,col) = counts(:,col) + [sess.nTrials(k); sess.hits(k)];
                groupedData(mouse).(groupFields{groupID}) = counts;
            end
        end
    end
end

%Plot the d prime curves of the example sham mouse, one line per day group
groupNames = fields(groupedData); groupNames = groupNames(2:end);
colors = {[0 0 0] [1 0 0] [1 0.6 0.6]};
shamExample = 11;

figure
for g = 1:length(groupNames)
    counts = groupedData(shamExample).(groupNames{g});
    rates = counts(2,:)./counts(1,:);
    xPowers = groupedData(shamExample).powers(2:end);

    %d' relative to the false-alarm rate taken from the first power column;
    %rates are mapped into [0.005, 0.995] before norminv
    zFalseAlarm = norminv(rates(1)*.99 + 0.005);
    dPrimeCurve = norminv(rates(2:end).*.99 + 0.005) - zFalseAlarm;

    if g == 1
        style = '--';
    else
        style = '-';
    end
    plot(xPowers,dPrimeCurve,'LineStyle',style,'Color',colors{g},'LineWidth',1.75)
    hold on
end

box off
%NOTE(review): these legend labels match the noise example, but this mouse's
%groups above are days [0 1 2] and [6 10]; confirm the intended labels.
legend({'Baseline' 'Days 0-3' 'Days 4-10'})
ylabel('d Prime')
xlabel('Laser Power (dB mW)')
ylim([-1 4.5])

if ismember(shamExample,[6 9 11])
    suptitle('Sham Exposed Mouse')
else
    suptitle('Noise Exposed Mouse')
end

%%
%Finally let's do the analysis for change in sensitivity for Figure 2F

%Start by looping through the mice and summing trials and hits per power
%for each group of days.

groups = {[-5;-4;-3;-2;-1] [0:3] [4:9]};
load('optoDetectionData.mat')

%Group index -> groupedData field (2 x nPowers: row 1 trials, row 2 hits)
groupFields = {'baseline' 'group1' 'group2'};

for mouse = 1:11

    powerList = unique(vertcat(allOptoData(mouse).laserData.intensities));
    groupedData(mouse).powers = powerList;
    groupedData(mouse).baseline = zeros(2,length(powerList));
    groupedData(mouse).group1 = zeros(2,length(powerList));
    groupedData(mouse).group2 = zeros(2,length(powerList));

    testDays = vertcat(allOptoData(mouse).laserData.dayNum);

    for groupID = 1:length(groups)
        fieldName = groupFields{groupID};
        for s = find(ismember(testDays,groups{groupID}))'
            tested = allOptoData(mouse).laserData(s).intensities;
            trialCounts = allOptoData(mouse).laserData(s).nTrials;
            hitCounts = allOptoData(mouse).laserData(s).hits;
            for k = 1:numel(tested)
                col = find(powerList == tested(k));
                groupedData(mouse).(fieldName)(:,col) = ...
                    groupedData(mouse).(fieldName)(:,col) + [trialCounts(k); hitCounts(k)];
            end
        end
    end
end


%Now let's get the measure of the change in d Prime

groupNames = fields(groupedData); groupNames = groupNames(2:end);

%For each mouse and day group, average (d' - baseline d') over the first
%powers, up to one past the first power with a positive baseline d'
for mouse = 1:11

    %Baseline d' curve: the first power column gives the false-alarm rate;
    %rates are mapped into [0.005, 0.995] so norminv stays finite
    baseFunc = groupedData(mouse).baseline;
    baseFunc = baseFunc(2,:)./baseFunc(1,:);
    FA = norminv(baseFunc(1)*.99 + 0.005);
    baseDPrimes = norminv(baseFunc(2:end).*.99 + 0.005); baseDPrimes = baseDPrimes - FA;

    %First power with a positive baseline d' (falls back to the first power)
    baseFirstInt = find(baseDPrimes > 0, 1, 'first');
    if isempty(baseFirstInt)
        baseFirstInt = 1;
    end
    %End of the averaging window, clamped so it cannot index past the last power
    lastInt = min(baseFirstInt + 1, numel(baseDPrimes));

    avgDiff = nan(1,length(groupNames)); %Reset per mouse
    for day = 1:length(groupNames) %Loop through grouped days

        %Get data from this day
        func = groupedData(mouse).(groupNames{day});
        func = func(2,:)./func(1,:);

        %Get d Primes
        FA = norminv(func(1)*.99 + 0.005);
        dPrimes = norminv(func(2:end).*.99 + 0.005); dPrimes = dPrimes - FA;

        %Difference from baseline, averaged over the window of low powers
        dPrimeDiffs = dPrimes - baseDPrimes;
        avgDiff(day) = mean(dPrimeDiffs(1:lastInt));

    end

    %Store the data
    summaryData(mouse).dPrimeDiff = avgDiff;

end


%Now plot out the above variable
figure
diff_noise = []; diff_sham = [];
for i = 1:11
    if ismember(i,[6 9 11])
        plot(summaryData(i).dPrimeDiff,'k--','LineWidth',0.5)
        diff_sham = [diff_sham; summaryData(i).dPrimeDiff];
    elseif ismember(i,[2:5 7:8]) %Excluding first test mouse (not age-matched), mouse with no thresh shift
        plot(summaryData(i).dPrimeDiff,'r','LineWidth',0.5)
        diff_noise = [diff_noise; summaryData(i).dPrimeDiff];
    end
    hold on
end
%Add plots of the means
mean_sham = nanmean(diff_sham,1);
mean_noise = nanmean(diff_noise,1);
plot(mean_sham,'k','LineWidth',2.5)
plot(mean_noise,'r','LineWidth',2.5)

title('d'' Differences')
ylabel('Change in d Prime Re:Baseline')
ax = gca;
ax.XTick = [1 2 3]; %One tick per day group, matching the three labels
ax.XTickLabel = {'Baseline' 'Days 0-3' 'Days 4+'};

%%
%Finally the stats for the above

clearvars -except diff_sham diff_noise

%Repeated-measures ANOVA on the post-exposure d' changes (day-group columns
%2 and 3), with exposure group as the between-subject factor
allValues = [diff_noise(:,2:3); diff_sham(:,2:3)];

%The two repeated measurements become table variables V1 and V2
varNames = {'V1'; 'V2'};
dataTable = array2table(allValues, 'VariableNames', varNames);

%Label each row with its exposure group. Counts come from the data so the
%labels stay aligned with allValues if mice are added or excluded.
nNoise = size(diff_noise,1);
nSham = size(diff_sham,1);
expGroup = [repmat({'Noise'},nNoise,1); repmat({'Sham'},nSham,1)];
dataTable.expGroup = expGroup;

%Within-subject design: a single factor, Day, with levels '1' and '2'
factorNames = {'Day'};
dayLabels = arrayfun(@num2str, (1:2)', 'UniformOutput', 0);
withinTable = table(dayLabels,'VariableNames',factorNames);

%Now that we have the tables set up, run the anova model
rmOPTO = fitrm(dataTable, 'V1-V2~expGroup','WithinDesign',withinTable);
[rAnovaResults] = ranova(rmOPTO, 'WithinModel','Day');
rAnovaResults
